{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pgFYFftQKxY5"
   },
   "source": [
    "<p style=\"align: center;\"><img align=center src=\"https://s8.hostingkartinok.com/uploads/images/2018/08/308b49fcfbc619d629fe4604bceb67ac.jpg\" style=\"height:450px;\" width=500/></p>\n",
    "\n",
    "<h3 style=\"text-align: center;\"><b>Школа глубокого обучения ФПМИ МФТИ</b></h3>\n",
    "<h3 style=\"text-align: center;\"><b>Базовый и продвинутый потоки. Осень 2021</b></h3>\n",
    "\n",
    "<h1 style=\"text-align: center;\"><b>Домашнее задание. Библиотека sklearn и классификация с помощью KNN</b></h1>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v4RCHGZULaWz"
   },
   "source": [
    "На основе [курса по Машинному Обучению ФИВТ МФТИ](https://github.com/ml-mipt/ml-mipt) и [Открытого курса по Машинному Обучению](https://habr.com/ru/company/ods/blog/322626/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F2acNQu1L94J"
   },
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Twe_cnn5KxY6"
   },
   "source": [
    "<h2 style=\"text-align: center;\"><b>K Nearest Neighbors (KNN)</b></h2>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YD0NXyUYKxY7"
   },
   "source": [
    "Метод ближайших соседей (k Nearest Neighbors, или kNN) — очень популярный метод классификации, также иногда используемый в задачах регрессии. Это один из самых понятных подходов к классификации. На уровне интуиции суть метода такова: посмотри на соседей; какие преобладают --- таков и ты. Формально основой метода является гипотеза компактности: если метрика расстояния между примерами введена достаточно удачно, то схожие примеры гораздо чаще лежат в одном классе, чем в разных. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CTa2jNZkKxY8"
   },
   "source": [
    "<img src='https://hsto.org/web/68d/a45/6f0/68da456f00f8434e87628dbe7e3f54a7.png' width=600>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5H7wPU0IKxY-"
   },
   "source": [
    "\n",
    "Для классификации каждого из объектов тестовой выборки необходимо последовательно выполнить следующие операции:\n",
    "\n",
    "* Вычислить расстояние до каждого из объектов обучающей выборки\n",
    "* Отобрать объектов обучающей выборки, расстояние до которых минимально\n",
    "* Класс классифицируемого объекта — это класс, наиболее часто встречающийся среди $k$ ближайших соседей"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T2docs4225pb"
   },
   "source": [
    "Будем работать с подвыборкой из [данных о типе лесного покрытия из репозитория UCI](http://archive.ics.uci.edu/ml/datasets/Covertype). Доступно 7 различных классов. Каждый объект описывается 54 признаками, 40 из которых являются бинарными. Описание данных доступно по ссылке."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AcjJQX3wKxZA"
   },
   "source": [
    "### Обработка данных"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "Ozcx5mVOKxZB"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ry4bMKaUjHJj"
   },
   "source": [
    "Сcылка на датасет (лежит в папке): https://drive.google.com/drive/folders/16TSz1P-oTF8iXSQ1xrt0r_VO35xKmUes?usp=sharing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "rvPrVRvK25pc"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>45</th>\n",
       "      <th>46</th>\n",
       "      <th>47</th>\n",
       "      <th>48</th>\n",
       "      <th>49</th>\n",
       "      <th>50</th>\n",
       "      <th>51</th>\n",
       "      <th>52</th>\n",
       "      <th>53</th>\n",
       "      <th>54</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2683</td>\n",
       "      <td>333</td>\n",
       "      <td>35</td>\n",
       "      <td>30</td>\n",
       "      <td>26</td>\n",
       "      <td>2743</td>\n",
       "      <td>121</td>\n",
       "      <td>173</td>\n",
       "      <td>179</td>\n",
       "      <td>6572</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2915</td>\n",
       "      <td>90</td>\n",
       "      <td>8</td>\n",
       "      <td>216</td>\n",
       "      <td>11</td>\n",
       "      <td>4433</td>\n",
       "      <td>232</td>\n",
       "      <td>228</td>\n",
       "      <td>129</td>\n",
       "      <td>4019</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2941</td>\n",
       "      <td>162</td>\n",
       "      <td>7</td>\n",
       "      <td>698</td>\n",
       "      <td>76</td>\n",
       "      <td>2783</td>\n",
       "      <td>227</td>\n",
       "      <td>242</td>\n",
       "      <td>148</td>\n",
       "      <td>1784</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3096</td>\n",
       "      <td>60</td>\n",
       "      <td>17</td>\n",
       "      <td>170</td>\n",
       "      <td>3</td>\n",
       "      <td>3303</td>\n",
       "      <td>231</td>\n",
       "      <td>202</td>\n",
       "      <td>99</td>\n",
       "      <td>5370</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2999</td>\n",
       "      <td>66</td>\n",
       "      <td>8</td>\n",
       "      <td>488</td>\n",
       "      <td>37</td>\n",
       "      <td>1532</td>\n",
       "      <td>228</td>\n",
       "      <td>225</td>\n",
       "      <td>131</td>\n",
       "      <td>2290</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 55 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      0    1   2    3   4     5    6    7    8     9  ...  45  46  47  48  49  \\\n",
       "0  2683  333  35   30  26  2743  121  173  179  6572  ...   0   0   0   0   0   \n",
       "1  2915   90   8  216  11  4433  232  228  129  4019  ...   0   0   0   0   0   \n",
       "2  2941  162   7  698  76  2783  227  242  148  1784  ...   0   0   0   0   0   \n",
       "3  3096   60  17  170   3  3303  231  202   99  5370  ...   0   0   0   0   0   \n",
       "4  2999   66   8  488  37  1532  228  225  131  2290  ...   0   0   0   0   0   \n",
       "\n",
       "   50  51  52  53  54  \n",
       "0   0   0   0   0   2  \n",
       "1   0   0   0   0   1  \n",
       "2   0   0   0   0   2  \n",
       "3   0   0   0   0   1  \n",
       "4   0   0   0   0   2  \n",
       "\n",
       "[5 rows x 55 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_data = pd.read_csv('forest_dataset.csv')\n",
    "all_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "_o8yXBPSKxZI"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000, 55)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_data.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "itCWxHEY25pg"
   },
   "source": [
    "Выделим значения метки класса в переменную `labels`, признаковые описания --- в переменную `feature_matrix`. Так как данные числовые и не имеют пропусков, переведем их в `numpy`-формат с помощью метода `.values`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "f_YIUOuV25ph"
   },
   "outputs": [],
   "source": [
    "labels = all_data[all_data.columns[-1]].values\n",
    "feature_matrix = all_data[all_data.columns[:-1]].values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FukXaH_r8PMQ"
   },
   "source": [
    "### Пара слов о sklearn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "k5S_0Lfc8PMR"
   },
   "source": [
    "**[sklearn](https://scikit-learn.org/stable/index.html)** -- удобная библиотека для знакомства с машинным обучением. В ней реализованны большинство стандартных алгоритмов для построения моделей и работ с выборками. У неё есть подробная документация на английском, с которой вам придётся поработать."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VhVDEG538PMS"
   },
   "source": [
    "`sklearn` предпологает, что ваши выборки имеют вид пар $(X, y)$, где $X$ -- матрица признаков, $y$ -- вектор истинных значений целевой переменной, или просто $X$, если целевые переменные неизвестны."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QJZQulsp8PMT"
   },
   "source": [
    "Познакомимся со вспомогательной функцией \n",
    "[train_test_split](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html).\n",
    "С её помощью можно разбить выборку на обучающую и тестовую части."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "Q030jzyY25pl"
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UkeB47mX8PMY"
   },
   "source": [
    "Вернёмся к датасету. Сейчас будем работать со всеми 7 типами покрытия (данные уже находятся в переменных `feature_matrix` и `labels`, если Вы их не переопределили). Разделим выборку на обучающую и тестовую с помощью метода `train_test_split`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "YJN0jFARKxZX"
   },
   "outputs": [],
   "source": [
    "train_feature_matrix, test_feature_matrix, train_labels, test_labels = train_test_split(\n",
    "    feature_matrix, labels, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "odC1c7X48PMb"
   },
   "source": [
    "Параметр `test_size` контролирует, какая часть выборки будет тестовой. Более подробно о нём можно прочитать в [документации](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z3fGvPqG8PMc"
   },
   "source": [
    "Основные объекты `sklearn` -- так называемые `estimators`, что можно перевести как *оценщики*, но не стоит, так как по сути это *модели*. Они делятся на **классификаторы** и **регрессоры**.\n",
    "\n",
    "В качестве примера модели можно привести классификаторы\n",
    "[метод ближайших соседей](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html) и \n",
    "[логистическую регрессию](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html). Что такое логистическая регрессия и как она работает сейчас не важно."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IuX8Rc7c8PMd"
   },
   "source": [
    "У всех моделей в `sklearn` обязательно должно быть хотя бы 2 метода (подробнее о методах и классах в python будет в следующих занятиях) -- `fit` и `predict`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZYokUkxO8PMe"
   },
   "source": [
    "Метод `fit(X, y)` отвечает за обучение модели и принимает на вход обучающую выборку в виде *матрицы признаков* $X$ и *вектора ответов* $y$.\n",
    "\n",
    "У обученной после `fit` модели теперь можно вызывать метод `predict(X)`, который вернёт предсказания этой модели на всех объектах из матрицы $X$ в виде вектора.\n",
    "\n",
    "Вызывать `fit` у одной и той же модели можно несколько раз, каждый раз она будет обучаться заново на переданном наборе данных.\n",
    "\n",
    "Ещё у моделей есть *гиперпараметры*, которые обычно задаются при создании модели.\n",
    "\n",
    "Рассмотрим всё это на примере логистической регрессии."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "ew0Ji_2D8PMe"
   },
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "c9KcMHXr8PMh"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Anaconda\\anaconda\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    }
   ],
   "source": [
    "# создание модели с указанием гиперпараметра C\n",
    "clf = LogisticRegression(C=1)\n",
    "# обучение модели\n",
    "clf.fit(train_feature_matrix, train_labels)\n",
    "# предсказание на тестовой выборке\n",
    "y_pred = clf.predict(test_feature_matrix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "h3gjg3pm8PMm"
   },
   "source": [
    "Теперь хотелось бы измерить качество нашей модели. Для этого можно использовать метод `score(X, y)`, который посчитает какую-то функцию ошибки на выборке $X, y$, но какую конкретно уже зависит от модели. Также можно использовать одну из функций модуля `metrics`, например [accuracy_score](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html), которая, как понятно из названия, вычислит нам точность предсказаний."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "J2Ej1Lni8PMn"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6075"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "accuracy_score(test_labels, y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "malIDW_P8PMp"
   },
   "source": [
    "Наконец, последним, о чём хотелось бы упомянуть, будет перебор гиперпараметров по сетке. Так как у моделей есть много гиперпараметров, которые можно изменять, и от этих гиперпараметров существенно зависит качество модели, хотелось бы найти наилучшие в этом смысле параметры. Самый простой способ это сделать -- просто перебрать все возможные варианты в разумных пределах.\n",
    "\n",
    "Сделать это можно с помощью класса [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html), который осуществляет поиск (search) по сетке (grid) и вычисляет качество модели с помощью кросс-валидации (CV).\n",
    "\n",
    "У логистической регрессии, например, можно поменять параметры `C` и `penalty`. Сделаем это. Учтите, что поиск может занять долгое время. Смысл параметров смотрите в документации."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "vq687Aoc8PMq"
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "OVnqHBvK8PMs"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'C': 2, 'penalty': 'l2'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Anaconda\\anaconda\\lib\\site-packages\\sklearn\\linear_model\\_sag.py:352: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# заново создадим модель, указав солвер\n",
    "clf = LogisticRegression(solver='saga')\n",
    "\n",
    "# опишем сетку, по которой будем искать\n",
    "param_grid = {\n",
    "    'C': np.arange(1, 5), # также можно указать обычный массив, [1, 2, 3, 4]\n",
    "    'penalty': ['l1', 'l2'],\n",
    "}\n",
    "\n",
    "# создадим объект GridSearchCV\n",
    "search = GridSearchCV(clf, param_grid, n_jobs=-1, cv=5, refit=True, scoring='accuracy')\n",
    "\n",
    "# запустим поиск\n",
    "search.fit(feature_matrix, labels)\n",
    "\n",
    "# выведем наилучшие параметры\n",
    "print(search.best_params_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DnVTFcvZ8PMv"
   },
   "source": [
    "В данном случае, поиск перебирает все возможные пары значений C и penalty из заданных множеств."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "ArKINrE_8PMw"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6418"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(labels, search.best_estimator_.predict(feature_matrix))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "okzpKY_I8PMz"
   },
   "source": [
    "Заметьте, что мы передаём в GridSearchCV всю выборку, а не только её обучающую часть. Это можно делать, так как поиск всё равно использует кроссвалидацию. Однако порой от выборки всё-же отделяют *валидационную* часть, так как гиперпараметры в процессе поиска могли переобучиться под выборку."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_mdJyxdo8PM1"
   },
   "source": [
    "В заданиях вам предстоит повторить это для метода ближайших соседей."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z8W__017KxZc"
   },
   "source": [
    "### Обучение модели"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "02uT6CPYKxZe"
   },
   "source": [
    "Качество классификации/регрессии методом ближайших соседей зависит от нескольких параметров:\n",
    "\n",
    "* число соседей `n_neighbors`\n",
    "* метрика расстояния между объектами `metric`\n",
    "* веса соседей (соседи тестового примера могут входить с разными весами, например, чем дальше пример, тем с меньшим коэффициентом учитывается его \"голос\") `weights`\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BHVNCaJ325qD"
   },
   "source": [
    "Обучите на датасете `KNeighborsClassifier` из `sklearn`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "o4CMnnOY25qD"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Anaconda\\anaconda\\lib\\site-packages\\sklearn\\neighbors\\_classification.py:228: FutureWarning: Unlike other reduction functions (e.g. `skew`, `kurtosis`), the default behavior of `mode` typically preserves the axis it acts along. In SciPy 1.11.0, this behavior will change: the default value of `keepdims` will become False, the `axis` over which the statistic is taken will be eliminated, and the value None will no longer be accepted. Set `keepdims` to True or False to avoid this warning.\n",
      "  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.7365"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "clf = KNeighborsClassifier()\n",
    "clf.fit(train_feature_matrix, train_labels)\n",
    "pred_labels = clf.predict(test_feature_matrix)\n",
    "accuracy_score(test_labels, pred_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r_2Mf8BiKxZk"
   },
   "source": [
    "### Вопрос 1:\n",
    "* Какое качество у вас получилось?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uFTIaPdrKxZl"
   },
   "source": [
    "Подберём параметры нашей модели"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8WzoRJZd25qF"
   },
   "source": [
    "* Переберите по сетке от `1` до `10` параметр числа соседей\n",
    "\n",
    "* Также вы попробуйте использоввать различные метрики: `['manhattan', 'euclidean']`\n",
    "\n",
    "* Попробуйте использовать различные стратегии вычисления весов: `[‘uniform’, ‘distance’]`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "4lMSy-6f25qG",
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=5, estimator=KNeighborsClassifier(), n_jobs=-1,\n",
       "             param_grid={'metric': ['manhattan', 'euclidean'],\n",
       "                         'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
       "                         'weights': ['uniform', 'distance']},\n",
       "             scoring='accuracy')"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "params = {'weights': [\"uniform\", \"distance\"], 'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'metric': [\"manhattan\", \"euclidean\"]}\n",
    "\n",
    "clf_grid = GridSearchCV(clf, params, cv=5, scoring='accuracy', n_jobs=-1)\n",
    "clf_grid.fit(train_feature_matrix, train_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SO7E6G8jKxZp"
   },
   "source": [
    "Выведем лучшие параметры"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "md48pHrMKxZq"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'metric': 'manhattan', 'n_neighbors': 4, 'weights': 'distance'}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf_grid.best_params_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "M05n9l8pKxZt"
   },
   "source": [
    "### Вопрос 2:\n",
    "* Какую metric следует использовать?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Pmjx38OoKxZt"
   },
   "source": [
    "### Вопрос 3:\n",
    "* Сколько n_neighbors следует использовать?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eqLeJUP8KxZu"
   },
   "source": [
    "### Вопрос 4:\n",
    "* Какой тип weights следует использовать?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aBmiDbvV25qI"
   },
   "source": [
    "Используя найденное оптимальное число соседей, вычислите вероятности принадлежности к классам для тестовой выборки (`.predict_proba`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "ig_vS8O925qI"
   },
   "outputs": [],
   "source": [
    "optimal_clf = KNeighborsClassifier(n_neighbors = 10)\n",
    "optimal_clf.fit(train_feature_matrix, train_labels)\n",
    "pred_prob = optimal_clf.predict_proba(test_feature_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "id": "2kkapT38KxZz"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAzoAAAKTCAYAAADR1X0mAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAsvElEQVR4nO3dfXBV9ZnA8SckEhAlKJgAYwSqFJD4gomrweK7WGSZ7datVC3oClVGUDHrtFLbLXVcwWo1ui60WCsy9YXpoq47vsZWfKNayZDWVcZSlYZiKAXXRLENNbn7hzVrTIAkIDf8+Hxmzgz35Jx7n3iGDt/+7j03J5PJZAIAACAhPbI9AAAAwK4mdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOXnZHqAjmpub4+233479998/cnJysj0OAACQJZlMJt57770YPHhw9Oix7XWbPSJ03n777SguLs72GAAAQDexbt26OPjgg7f58z0idPbff/+I+OiX6du3b5anAQAAsqWhoSGKi4tbGmFb9ojQ+fjtan379hU6AADADj/S4mYEAABAcoQOAACQHKEDAAAkZ4/4jA4AAGxLU1NT/PWvf832GOwi++yzT+Tm5u708wgdAAD2SJlMJjZs2BDvvvtutkdhF+vXr18MHDhwp75DU+gAALBH+jhyCgsLY9999/XF8gnIZDLxwQcfxMaNGyMiYtCgQV1+LqEDAMAep6mpqSVy+vfvn+1x2IV69+4dEREbN26MwsLCLr+Nzc0IAADY43z8mZx99903y5PwWfj4uu7MZ6+EDgAAeyxvV0vTrriuQgcAAEiO0AEAgEQNHTo0KisrWx7n5OTEQw89tFPPuSueY3dwMwIAAJIy9OpHdttrrZ0/cbe91q5QV1cXBxxwQIeOnTt3bjz00ENRU1PT5efIJqEDAADd2NatW6Nnz5675LkGDhzYLZ5jd/DWNQAA2I1OPvnkmDVrVsyaNSv69esX/fv3j29/+9uRyWQi4qO3m1133XVx4YUXRkFBQXz961+PiIgVK1bEiSeeGL17947i4uK4/PLLY8uWLS3Pu3Hjxpg0aVL07t07hg0bFvfcc0+b1/70287+8Ic/xFe/+tU48MADo0+fPlFWVhYvvfRSLF68OL73ve/Fr3/968jJyYmcnJxYvHhxu8/xyiuvxKmnnhq9e/eO/v37x8UXXxzvv/9+y88vvPDC+NKXvhQ33XRTDBo0KPr37x8zZ87cqTuqdYTQAQCA3ezuu++OvLy8eOmll+K2226LW265JX784x+3/PzGG2+MkpKSqK6uju985zvxyiuvxJlnnhlf/vKX4ze/+U0sXbo0nn/++Zg1a1bLORdeeGGsXbs2fvGLX8R//ud/xoIFC1q+eLM977//fpx00knx9ttvx8MPPxy//vWv4xvf+EY0NzfH5MmT41/+5V9i9OjRUVdXF3V1dTF58uQ2z/HBBx/EF7/4xTjggAPi5Zdfjp/97Gfx1FNPtZorIuLpp5+ON954I55++um4++67Y/HixS3h9Fnx1jUAANjNiouL45ZbbomcnJwYMWJEvPLKK3HLLbe0rN6ceuqpcdVVV7UcP3Xq1DjvvPNi9uzZERExfPjwuO222+Kkk06KhQsXRm1tbTz22GPx4osvxnHHHRcREXfeeWeMGjVqmzPce++98ac//SlefvnlOPDAAyMi4rDDDmv5+X777Rd5eXnbfavaPffcE3/+859jyZIl0adPn4iIuP3222PSpElxww03RFFRUUREHHDAAXH77bdHbm5ujBw5MiZOnBg///nPW37fz4IVHQAA2M2OP/74Vt8VU15eHmvWrImmpqaIiCgrK2t1fHV1dSxevDj222+/lu3MM8+M5ubmeOutt2L16tWRl5fX6ryRI0dGv379tjlDTU1NjBkzpiVyumL16tVx1FFHtURORMQJJ5wQzc3N8frrr7fsGz16dOTm5rY8HjRo0HZXm3YFKzoAANDNfDIcIiKam5vjkksuicsvv7zNsYccckhLVHTmizZ79+69c0NGRCaT2eZrfnL/Pvvs0+Znzc3NO/3622NFBwAAdrMXX3yxzePhw4e3WvX4pGOOOSZeffXVOOyww9psPXv2jFGjRsWHH34YK1eubDnn9ddfj3fffXebMxx55JFRU1MT77zzTrs/79mzZ8sK07YcfvjhUVNT0+qmCC+88EL06NEjPv/5z2/33M+a0AEAgN1s3bp1UVFREa+//nrcd9998e///u9xxRVXbPP4b37zm/HLX/4yZs6cGTU1NbFmzZp4+OGH47LLLouIiBEjRsQXv/jF+PrXvx4vvfRSVFdXx/Tp07e7anPuuefGwIED40tf+lK88MIL8eabb8ayZcvil7/8ZUR8dPe3t956K2pqamLTpk3R2NjY5jnOP//86NWrV1xwwQXxP//zP/H000/HZZddFlOmTGn5fE62CB0AANjNpk6dGn/+85/j7/7u72LmzJlx2WWXxcUXX7zN44888sh45plnYs2aNTFu3LgYM2ZMfOc734lBgwa1HHPXXXdFcXFxnHTSSfHlL385Lr744igsLNzmc/bs2TOefPLJKCwsjLPOOiuOOOKImD9/fsuq0tlnnx1f/OIX45RTTomDDjoo7rvvvjbPse+++8YTTzwR77zzThx77LHxT//0T3HaaafF7bffvhP/dXaNnMzHN+zuxhoaGqKgoCDq6+ujb9++2R4HAIAs+8tf/hJvvfVWDBs2LHr16pXtcTrl5JNPjqOPPjoqKyuzPUq3tb3r29E2sKIDAAAkR+gAAADJcXtpAADYjZYvX57tEfYKVnQAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAA9nBr166NnJycqKmpyfYo3Ybv0QEAIC1zC3bja9XvvteiU6zoAABAFm3dujXbIyRJ6AAAwG508sknx6xZs6KioiIGDBgQZ5xxRrz22mtx1llnxX777RdFRUUxZcqU2LRpU8s5jz/+eHzhC1+Ifv36Rf/+/ePv//7v44033sjib9H9CR0AANjN7r777sjLy4sXXngh5s+fHyeddFIcffTRsXLlynj88cfjj3/8Y5xzzjktx2/ZsiUqKiri5Zdfjp///OfRo0eP+Md//Mdobm7O4m/RvfmMDgAA7GaHHXZYfP/734+IiH/913+NY445Jq6//vqWn//kJz+J4uLi+O1vfxuf//zn4+yzz251/p133hmFhYXx2muvRUlJyW6dfU9hRQcAAHazsrKylj9XV1fH008/Hfvtt1/LNnLkyIiIlrenvfHGG3HeeefF5z73uejbt28MGzYsIiJqa2t3//B7CCs6AACwm/Xp06flz83NzTFp0qS44YYb2hw3aNCgiIiYNGlSFBcXxx133BGDBw+O5ubmKCkpcSOD7ehS6CxYsCBuvPHGqKuri9GjR0dlZWWMGzeu3WOXL18ep5xySpv9q1evbilV2NsMvfqRbI/QIWvnT8z2CACQvGOOOSaWLVsWQ4cOjby8tv8837x5c6xevTp+9KMftfyb+/nnn9/dY+5xOv3WtaVLl8bs2bPjmmuuiVWrVsW4ceNiwoQJO1w2e/3116Ourq5lGz58eJeHBgCAVMycOTPeeeedOPfcc+NXv/pVvPnmm/Hkk0/GRRddFE1NTXHAAQdE//79Y9GiRfG73/0ufvGLX0RFRUW2x+72Oh06N998c0ybNi2mT58eo0aNisrKyiguLo6FCxdu97zCwsIYOHBgy5abm9vloQEAIBWDBw+OF154IZqamuLMM8+MkpKSuOKKK6KgoCB69OgRPXr0iPvvvz+qq6ujpKQkrrzyyrjxxhuzPXa316m3rm3dujWqq6vj6quvbrV//PjxsWLFiu2eO2bMmPjLX/4Shx9+eHz7299u9+1sH2tsbIzGxsaWxw0NDZ0ZEwCAvdnc+mxPsF3Lly9vs2/48OHxwAMPbPOc008/PV577bVW+zKZTMufhw4d2uoxnVzR2bRpUzQ1NUVRUVGr/UVFRbFhw4Z2zxk0aFAsWrQoli1bFg888ECMGDEiTjvttHj22We3+Trz5s2LgoKClq24uLgzYwIAAHu5Lt2MICcnp9XjTCbTZt/HRowYESNGjGh5XF5eHuvWrYubbropTjzxxHbPmTNnTqv3HTY0NIgdAACgwzq1ojNgwIDIzc1ts3qzcePGNqs823P88cfHmjVrtvnz/Pz86Nu3b6sNAACgozoVOj179ozS0tKoqqpqtb+qqirGjh3b4edZtWpVyz3BAQAAdrVOv3WtoqIipkyZEmVlZVFeXh6LFi2K2tramDFjRkR89Laz9evXx5IlSyIiorKyMoYOHRqjR4+OrVu3xk9/+tNYtmxZLFu2bNf+JgAA7HV8AD9Nu+K6djp0Jk+eHJs3b45rr7026urqoqSkJB599NEYMmRIRETU1dW1+k6drVu3xlVXXRXr16+P3r17x+jRo+ORRx6Js846a6eHBwBg77TPPvtERMQHH3wQvXv3zvI07GoffPBBRPz/de6KnMwekMENDQ1RUFAQ9fX1Pq9DEoZe/Ui2R+iQtfMnZnsEANimurq6ePfdd6OwsDD23Xffbd4ciz1HJpOJDz74IDZu3Bj9+vVr9+MuHW2DLt11DQAAsm3gwIER8dGNsUhLv379Wq5vVwkdAAD2SDk5OTFo0KAoLCyMv/71r9keh11kn332idzc3J1+HqEDAMAeLTc3d5f8w5i0dOr20gAAAHsCoQMAACRH6AAAAMkROgAAQHKEDgAAkBx3XSM9cwuyPUEH3JvtAQAAkmZFBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJLTpdBZsGBBDBs2LHr16hWlpaXx3HPPdei8F154IfLy8uLoo4/uyssCAAB0SKdDZ+nSpTF79uy45pprYtWqVTFu3LiYMGFC1NbWbve8+vr6mDp1apx22mldHhYAAKAjOh06N998c0ybNi2mT58eo0aNisrKyiguLo6FCxdu97xLLrkkzjvvvCgvL+/ysAAAAB3RqdDZunVrVFdXx/jx41vtHz9+fKxYsWKb5911113xxhtvxHe/+90OvU5jY2M0NDS02gAAADqqU6GzadOmaGpqiqKiolb7i4qKYsOGDe2es2bNmrj66qvjnnvuiby8vA69zrx586KgoKBlKy4u7syYAADAXq5LNyPIyclp9TiTybTZFxHR1NQU5513Xnzve9+Lz3/+8x1+/jlz5kR9fX3Ltm7duq6MCQAA7KU6tsTyNwMGDIjc3Nw2qzcbN25ss8oTEfHee+/FypUrY9WqVTFr1qyIiGhubo5MJhN5eXnx5JNPxqmnntrmvPz8/MjPz+/MaAAAAC06taLTs2fPKC0tjaqqqlb7q6qqYuzYsW2O79u3b7zyyitRU1PTss2YMSNGjBgRNTU1cdxxx+3c9AAAAO3o1IpORERFRUVMmTIlysrKory8PBYtWhS1tbUxY8aMiPjobWfr16+PJUuWRI8ePaKkpKTV+YWFhdGrV682+wEAAHaVTofO5MmTY/PmzXHttddGXV1dlJSUxKOPPhpDhgyJiIi6urodfqcOAADAZyknk8lksj3EjjQ0NERBQUHU19dH3759sz0O3d3cgmxPsEND/3JvtkfokLXzJ2Z7BACAVjraBl266xoAAEB3JnQAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSk5ftAfjsDL36kWyPsENr50/M9ggAACTIig4AAJAcoQMAACRH6AAAAMnxGZ2umFuQ7Qk66N5sDwAAAFlhRQcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEhOl0JnwYIFMWzYsOjVq1eUlpbGc889t81jn3/++TjhhBOif//+0bt37xg5cmTccsstXR4YAABgR/I6e8LSpUtj9uzZsWDBgjjhhBPiRz/6UUyYMCFee+21OOSQQ9oc36dPn5g1a1YceeSR0adPn3j++efjkksuiT59+sTFF1+8S34JAACAT+r0is7NN98c06ZNi+nTp8eoUaOisrIyiouLY+HChe0eP2bMmDj33HNj9OjRMXTo0Pja174WZ5555nZXgQAAAHZGp0Jn69atUV1dHePHj2+1f/z48bFixYoOPceqVatixYoVcdJJJ23zmMbGxmhoaGi1AQAAdFSnQmfTpk3R1NQURUVFrfYXFRXFhg0btnvuwQcfHPn5+VFWVhYzZ86M6dOnb/PYefPmRUFBQctWXFzcmTEBAIC9XJduRpCTk9PqcSaTabPv05577rlYuXJl/PCHP4zKysq47777tnnsnDlzor6+vmVbt25dV8YEAAD2Up26GcGAAQMiNze3zerNxo0b26zyfNqwYcMiIuKII46IP/7xjzF37tw499xz2z02Pz8/8vPzOzMaAABAi06t6PTs2TNKS0ujqqqq1f6qqqoYO3Zsh58nk8lEY2NjZ14aAACgwzp9e+mKioqYMmVKlJWVRXl5eSxatChqa2tjxowZEfHR287Wr18fS5YsiYiI//iP/4hDDjkkRo4cGREffa/OTTfdFJdddtku/DUAAAD+X6dDZ/LkybF58+a49tpro66uLkpKSuLRRx+NIUOGREREXV1d1NbWthzf3Nwcc+bMibfeeivy8vLi0EMPjfnz58cll1yy634LAACAT8jJZDKZbA+xIw0NDVFQUBD19fXRt2/fbI8TMbcg2xN0yNC/3JvtEXZo7fyJu/5J94Drsydcm4jP6PoAAOyEjrZBl+66BgAA0J0JHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABITpdCZ8GCBTFs2LDo1atXlJaWxnPPPbfNYx944IE444wz4qCDDoq+fftGeXl5PPHEE10eGAAAYEc6HTpLly6N2bNnxzXXXBOrVq2KcePGxYQJE6K2trbd45999tk444wz4tFHH43q6uo45ZRTYtKkSbFq1aqdHh4AAKA9nQ6dm2++OaZNmxbTp0+PUaNGRWVlZRQXF8fChQvbPb6ysjK+8Y1vxLHHHhvDhw+P66+/PoYPHx7//d//vdPDAwAAtKdTobN169aorq6O8ePHt9o/fvz4WLFiRYeeo7m5Od5777048MADt3lMY2NjNDQ0tNoAAAA6qlOhs2nTpmhqaoqioqJW+4uKimLDhg0deo4f/OAHsWXLljjnnHO2ecy8efOioKCgZSsuLu7MmAAAwF6uSzcjyMnJafU4k8m02dee++67L+bOnRtLly6NwsLCbR43Z86cqK+vb9nWrVvXlTEBAIC9VF5nDh4wYEDk5ua2Wb3ZuHFjm1WeT1u6dGlMmzYtfvazn8Xpp5++3WPz8/MjPz+/M6MBAAC06NSKTs+ePaO0tDSqqqpa7a+qqoqxY8du87z77rsvLrzwwrj33ntj4sSJXZsUAACggzq1ohMRUVFREVOmTImysrIoLy+PRYsWRW1tbcyYMSMiPnrb2fr162PJkiUR8VHkTJ06NW699dY4/vjjW1aDevfuHQUFBbvwVwEAAPhIp0Nn8uTJsXnz5rj22mujrq4uSkpK4tFHH40hQ4ZERERdXV2r79T50Y9+FB9++GHMnDkzZs6c2bL/ggsuiMWLF+/8bwAAAPApnQ6diIhLL700Lr300nZ/9ul4Wb58eVdeAgAAoMu6dNc1AACA7kzoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEByhA4AAJAcoQMAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEByhA4AAJAcoQMAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEByhA4AAJAcoQMAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEByhA4AAJAcoQMAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEByhA4AAJAcoQMAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEByhA4AAJAcoQMAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEByhA4AAJAcoQMAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEByhA4AAJAcoQMAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEByuhQ6CxYsiGHDhkWvXr2itLQ0nnvuuW0eW1dXF+edd16MGDEievToEbNnz+7qrAAAAB3S6dBZunRpzJ49O6655ppYtWpVjBs3LiZMmBC1tbXtHt/Y2BgHHXRQXHPNNXHUUUft9MAAAAA70unQufnmm2PatGkxffr0GDVqVFRWVkZxcXEsXLiw3eOHDh0at956a0ydOjUKCgp2emAAAIAd6VTobN26Naqrq2P8+PGt9o8fPz5WrFixy4ZqbGyMhoaGVhsAAEBHdSp0Nm3aFE1NTVFUVNRqf1FRUWzYsGGXDTVv3rwoKCho2YqLi3fZcwMAAOnr0s0IcnJyWj3OZDJt9u2MOXPmRH19fcu2bt26XfbcAABA+vI6c/CAAQMiNze3zerNxo0b26zy7Iz8/PzIz8/fZc8HAADsXTq1otOzZ88oLS2NqqqqVvurqqpi7Nixu3QwAACArurUik5EREVFRUyZMiXKysqivLw8Fi1aFLW1tTFjxoyI+OhtZ+vXr48lS5a0nFNTUxMREe+//3786U9/ipqamujZs2ccfvjhu+a3AAAA+IROh87kyZNj8+bNce2110ZdXV2UlJTEo48+GkOGDImIj74g9NPfqTNmzJiWP1dXV8e9994bQ4YMibVr1+7c9AAAAO3odOhERFx66aVx6aWXtvuzxYsXt9mXyWS68jIAAABd0qW7rgEAAHRnQgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5ORlewBgLzK3INsTdMzc+mxPAADsJKED8ClDr34k2yPs0Nr5E7M9AgB0a966BgAAJEfoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEByhA4AAJAcoQMAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcoQOAACQnLxsDwAAAJ+ZuQXZnmDH5tZne4IkWdEBAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5AgdAAAgOUIHAABITl62BwAA2KPNLcj2BB0ztz7bE8BuZUUHAABIjtABAACSI3QAAIDk+IwOAMBeYOjVj2R7hB1aO39itkfIij3h2kTsedfHig4AAJAcoQMAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcnxhKADsCeYWZHuCHZtbn+0JAFoIHQBgl9gTvt19T/tmd6DrvHUNAABIjtABAACSI3QAAIDkCB0AACA5QgcAAEiO0AEAAJIjdAAAgOQIHQAAIDlCBwAASI7QAQAAkiN0AACA5ORlewAAuom5BdmeYMfm1md7AgD2EEIHgD3G0KsfyfYIHbJ2/sRsjwCw1/PWNQAAIDldCp0FCxbEsGHDolevXlFaWhrPPffcdo9/5plnorS0NHr16hWf+9zn4oc//GGXhgUAAOiITofO0qVLY/bs2XHNNdfEqlWrYty4cTFhwoSora1t9/i33norzjrrrBg3blysWrUqvvWtb8Xll18ey5Yt2+nhAQAA2tPpz+jcfPPNMW3atJg+fXpERFRWVsYTTzwRCxcujHnz5rU5/oc//GEccsghUVlZGRERo0aNipUrV8ZNN90UZ599druv0djYGI2NjS2P6+s/+vBpQ0NDZ8f9bDRmsj1BhzQ3fpDtEXboM7mme8D12ROuTcRncH32gGsTsWdcH393ujfXp/vaW69NhOvTne0J1yai+/xb/OM5MpkdXNtMJzQ2NmZyc3MzDzzwQKv9l19+eebEE09s95xx48ZlLr/88lb7HnjggUxeXl5m69at7Z7z3e9+NxMRNpvNZrPZbDabzdbutm7duu22S6dWdDZt2hRNTU1RVFTUan9RUVFs2LCh3XM2bNjQ7vEffvhhbNq0KQYNGtTmnDlz5kRFRUXL4+bm5njnnXeif//+kZOT05mR91oNDQ1RXFwc69ati759+2Z7HD7BteneXJ/uy7Xp3lyf7s316b5cm87LZDLx3nvvxeDBg7d7XJduL/3p2MhkMtsNkPaOb2//x/Lz8yM/P7/Vvn79+nVhUvr27esvTTfl2nRvrk/35dp0b65P9+b6dF+uTecUFBTs8JhO3YxgwIABkZub22b1ZuPGjW1WbT42cODAdo/Py8uL/v37d+blAQAAOqRTodOzZ88oLS2NqqqqVvurqqpi7Nix7Z5TXl7e5vgnn3wyysrKYp999unkuAAAADvW6dtLV1RUxI9//OP4yU9+EqtXr44rr7wyamtrY8aMGRHx0edrpk6d2nL8jBkz4ve//31UVFTE6tWr4yc/+UnceeedcdVVV+2634I28vPz47vf/W6btwCSfa5N9+b6dF+uTffm+nRvrk/35dp8dnIymR3dl62tBQsWxPe///2oq6uLkpKSuOWWW+LEE0+MiIgLL7ww1q5dG8uXL285/plnnokrr7wyXn311Rg8eHB885vfbAkjAACAXa1LoQMAANCddfqtawAAAN2d0AEAAJIjdAAAgOQIHQAAIDlCJzHPPvtsTJo0KQYPHhw5OTnx0EMPZXsk/mbevHlx7LHHxv777x+FhYXxpS99KV5//fVsj0VELFy4MI488siWb6UuLy+Pxx57LNtjsQ3z5s2LnJycmD17drZHISLmzp0bOTk5rbaBAwdmeyz+Zv369fG1r30t+vfvH/vuu28cffTRUV1dne2xiIihQ4e2+buTk5MTM2fOzPZoyRA6idmyZUscddRRcfvtt2d7FD7lmWeeiZkzZ8aLL74YVVVV8eGHH8b48eNjy5Yt2R5tr3fwwQfH/PnzY+XKlbFy5co49dRT4x/+4R/i1VdfzfZofMrLL78cixYtiiOPPDLbo/AJo0ePjrq6upbtlVdeyfZIRMT//u//xgknnBD77LNPPPbYY/Haa6/FD37wg+jXr1+2RyM++t+zT/69qaqqioiIr3zlK1meLB152R6AXWvChAkxYcKEbI9BOx5//PFWj++6664oLCyM6urqlu+hIjsmTZrU6vG//du/xcKFC+PFF1+M0aNHZ2kqPu3999+P888/P+6444647rrrsj0On5CXl2cVpxu64YYbori4OO66666WfUOHDs3eQLRy0EEHtXo8f/78OPTQQ+Okk07K0kTpsaIDWVJfXx8REQceeGCWJ+GTmpqa4v77748tW7ZEeXl5tsfhE2bOnBkTJ06M008/Pduj8Clr1qyJwYMHx7Bhw+KrX/1qvPnmm9keiYh4+OGHo6ysLL7yla9EYWFhjBkzJu64445sj0U7tm7dGj/96U/joosuipycnGyPkwyhA1mQyWSioqIivvCFL0RJSUm2xyEiXnnlldhvv/0iPz8/ZsyYEQ8++GAcfvjh2R6Lv7n//vujuro65s2bl+1R+JTjjjsulixZEk888UTccccdsWHDhhg7dmxs3rw526Pt9d58881YuHBhDB8+PJ544omYMWNGXH755bFkyZJsj8anPPTQQ/Huu+/GhRdemO1RkuKta5AFs2bNit/85jfx/PPPZ3sU/mbEiBFRU1MT7777bixbtiwuuOCCeOaZZ8RON7Bu3bq44oor4sknn4xevXplexw+5ZNvlz7iiCOivLw8Dj300Lj77rujoqIii5PR3NwcZWVlcf3110dExJgxY+LVV1+NhQsXxtSpU7M8HZ905513xoQJE2Lw4MHZHiUpVnRgN7vsssvi4YcfjqeffjoOPvjgbI/D3/Ts2TMOO+ywKCsri3nz5sVRRx0Vt956a7bHIiKqq6tj48aNUVpaGnl5eZGXlxfPPPNM3HbbbZGXlxdNTU3ZHpFP6NOnTxxxxBGxZs2abI+y1xs0aFCb/7Nm1KhRUVtbm6WJaM/vf//7eOqpp2L69OnZHiU5VnRgN8lkMnHZZZfFgw8+GMuXL49hw4ZleyS2I5PJRGNjY7bHICJOO+20Nnfx+ud//ucYOXJkfPOb34zc3NwsTUZ7GhsbY/Xq1TFu3Lhsj7LXO+GEE9p8jcFvf/vbGDJkSJYmoj0f35xo4sSJ2R4lOUInMe+//3787ne/a3n81ltvRU1NTRx44IFxyCGHZHEyZs6cGffee2/813/9V+y///6xYcOGiIgoKCiI3r17Z3m6vdu3vvWtmDBhQhQXF8d7770X999/fyxfvrzNnfLIjv3337/NZ9n69OkT/fv39xm3buCqq66KSZMmxSGHHBIbN26M6667LhoaGuKCCy7I9mh7vSuvvDLGjh0b119/fZxzzjnxq1/9KhYtWhSLFi3K9mj8TXNzc9x1111xwQUXRF6ef5bvav6LJmblypVxyimntDz++P3RF1xwQSxevDhLUxHx0ZdSRkScfPLJrfbfddddPnyYZX/84x9jypQpUVdXFwUFBXHkkUfG448/HmeccUa2R4Nu7w9/+EOce+65sWnTpjjooIPi+OOPjxdffNGqQTdw7LHHxoMPPhhz5syJa6+9NoYNGxaVlZVx/vnnZ3s0/uapp56K2trauOiii7I9SpJyMplMJttDAAAA7EpuRgAAACRH6AAAAMkROgAAQHKEDgAAkByhAwAAJEfoAAAAyRE6AABAcoQOAACQHKEDAAAkR+gAAADJEToAAEBy/g8JNXXAW5CISAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 1000x800 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "\n",
    "unique, freq = np.unique(test_labels, return_counts=True)\n",
    "freq = list(map(lambda x: x / len(test_labels),freq))\n",
    "\n",
    "pred_freq = pred_prob.mean(axis=0)\n",
    "plt.figure(figsize=(10, 8))\n",
    "plt.bar(range(1, 8), pred_freq, width=0.4, align=\"edge\", label='prediction')\n",
    "plt.bar(range(1, 8), freq, width=-0.4, align=\"edge\", label='real')\n",
    "plt.ylim(0, 0.54)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gp4uDyLmKxZ3"
   },
   "source": [
    "### Вопрос 5:\n",
    "* Какая прогнозируемая вероятность pred_freq класса под номером 3 (до 2 знаков после запятой)?"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
